#----------------------------------------------------------------------
#  Pointcloud generation using the Octree technique
#  Author: Andrea Pavan
#  Date: 18/12/2022
#  License: GPLv3-or-later
#----------------------------------------------------------------------
using ElasticArrays;
using PyPlot;
include("utils.jl");


#problem definition
#domain extent along x, y and z respectively
(l1, l2, l3) = (10.0, 5.0, 2.0);
meshSize = 1.5;             #target spacing between internal nodes
surfaceMeshSize = 0.75;     #target spacing between boundary nodes


#surface points
time1 = time();
pointcloud = ElasticArray{Float64}(undef,3,0);      #3xN matrix containing the coordinates [X;Y;Z] of each node
boundaryNodes = Vector{Int}(undef,0);       #indices of the boundary nodes (filled once all surface nodes are placed)
normals = ElasticArray{Float64}(undef,3,0);     #3xN matrix containing the components [nx;ny;nz] of the normal of each boundary node
(section,sectionnormals) = defaultCrossSection(l2, l3, surfaceMeshSize);

#lateral surface: replicate the cross-section nodes (with their normals, nx=0) along x.
#The ring spacing is the value closest to surfaceMeshSize that divides l1 exactly, so the
#last ring lies exactly at x=l1. A plain 0:surfaceMeshSize:l1 range stops short of l1 whenever
#l1 is not a multiple of surfaceMeshSize, which would leave the x=l1 end cap without its rim.
nSlices = max(1, round(Int, l1/surfaceMeshSize));
for x in range(0, l1, length=nSlices+1)
    append!(pointcloud, vcat(fill(x,1,size(section,2)), section));
    append!(normals, vcat(zeros(Float64,1,size(sectionnormals,2)), sectionnormals));
end

#end caps at x=0 and x=l1: grid of nodes inside the cross-section, kept one spacing away from
#its outer edge (presumably the rim is covered by the first and last lateral rings — the
#section nodes are assumed to lie on the cross-section outline; confirm in defaultCrossSection)
for y in -l2/2+surfaceMeshSize:surfaceMeshSize:l2/2-surfaceMeshSize
    for z in -l3+surfaceMeshSize:surfaceMeshSize:0-surfaceMeshSize
        #inside the central rectangle, or inside one of the two circular sides
        if abs(y)<(l2-l3)/2 || (abs(y)-(l2-l3)/2)^2+(z+l3/2)^2<(l3/2)^2
            append!(pointcloud, [0,y,z]);
            append!(normals, [-1,0,0]);
            append!(pointcloud, [l1,y,z]);
            append!(normals, [1,0,0]);
        end
    end
end
boundaryNodes = collect(1:size(pointcloud,2));


#=
#internal points - cartesian
for y in -l2/2:meshSize:l2/2
    for z in -l3:meshSize:0
        if abs(y)<(l2-l3)/2 || (abs(y)-(l2-l3)/2)^2+(z+l3/2)^2<(l3/2)^2
            for x in 0:meshSize:l1
                #append!(pointcloud, [x,y,z]);
                append!(pointcloud, [x,y,z]+(rand(Float64,3).-0.5).*meshSize/5);
            end
        end
    end
end
=#

#internal points - build octree from the boundary nodes placed so far
println("Building octree...");
octree, octreeSize, octreeCenter, octreePoints, octreeNpoints = buildOctree(pointcloud);
println("Octree generated");
octreeSizeMap = Float64[];      #cell size recorded for every internal node generated below

#optional extra refinement at the boundaries (currently disabled):
#divide every cell whose point count is exactly one
#=octreeRefining = findall(octreeNpoints.==1);
for i in octreeRefining
    divideOctreeCell!(i,octree,octreeSize,octreeCenter,octreePoints,octreeNpoints);
end=#

#indices of the cells that are not divided (leaves); assumes divided cells carry a
#negative point count — NOTE(review): confirm against buildOctree in utils.jl
octreeLeaves = findall(>=(0), octreeNpoints);


#balancing octree
#=println("Balancing octree...");
function searchOctreeCell(P)
    for i in octreeLeaves
        if all(abs.(P-octreeCenter[:,i]).<=octreeSize[i]/2)
            return i;
        end
    end
    return 0;
end

cellsToDivide = Vector{Int}(undef,0);
for j in octreeLeaves
    #check the neighbors size
    vPoints = [octreeCenter[:,j]+[octreeSize[j],octreeSize[j],octreeSize[j]],
            octreeCenter[:,j]+[octreeSize[j],octreeSize[j],-octreeSize[j]],
            octreeCenter[:,j]+[octreeSize[j],-octreeSize[j],octreeSize[j]],
            octreeCenter[:,j]+[octreeSize[j],-octreeSize[j],-octreeSize[j]],
            octreeCenter[:,j]+[-octreeSize[j],octreeSize[j],octreeSize[j]],
            octreeCenter[:,j]+[-octreeSize[j],octreeSize[j],-octreeSize[j]],
            octreeCenter[:,j]+[-octreeSize[j],-octreeSize[j],octreeSize[j]],
            octreeCenter[:,j]+[-octreeSize[j],-octreeSize[j],-octreeSize[j]]];
    neighborIdx = Vector{Int}(undef,8);
    for k=1:8
        neighborIdx[k] = searchOctreeCell(vPoints[k]);
    end
    unique!(neighborIdx);
    neighborIdx = neighborIdx[findall(neighborIdx.!=0)];
    for idx in neighborIdx
        if octreeSize[idx]>2.1*octreeSize[j]
            #divide neighbor
            push!(cellsToDivide,idx);
            #divide!(idx);
        end
    end
end
println("Octree balanced");=#


#generate pointcloud
#One candidate internal node is drawn for every octree leaf whose point count is not 1.
#A candidate is kept only if it lies inside the domain (x in [0,l1], stadium-shaped
#cross-section) and at least boundaryMargin away from the boundary surface.
boundaryMargin = 0.1;       #minimum distance between internal nodes and the domain boundary
for i in octreeLeaves
    if octreeNpoints[i] != 1
        #jitter the cell center by up to ±meshSize/10 per axis to avoid a perfectly regular lattice
        candidatePoint = octreeCenter[:,i] + (rand(Float64,3).-0.5).*meshSize/5;
        (cx,cy,cz) = candidatePoint;
        #every coordinate is bounded explicitly: the octree cells may extend beyond the
        #pointcloud bounding box (e.g. a cubic root cell), so an unbounded z or x test
        #would accept candidates below the bottom face or beyond the end caps
        insideLength = boundaryMargin <= cx <= l1-boundaryMargin;
        insideRectangle = abs(cy)<(l2-l3)/2 && -l3+boundaryMargin <= cz <= -boundaryMargin;
        insideRoundedSides = (abs(cy)-(l2-l3)/2)^2+(cz+l3/2)^2 < (l3/2-boundaryMargin)^2;
        if insideLength && (insideRectangle || insideRoundedSides)
            append!(pointcloud, candidatePoint);
            append!(octreeSizeMap, octreeSize[i]);
        end
    end
end

#internal nodes are all the nodes appended after the boundary ones
internalNodes = collect(length(boundaryNodes)+1:size(pointcloud,2));
println("Generated pointcloud in ", round(time()-time1,digits=2), " s");
println("Pointcloud properties:");
println("  Boundary nodes: ",length(boundaryNodes));
println("  Internal nodes: ",length(internalNodes));
println("  Memory: ",memoryUsage(pointcloud,boundaryNodes));


#pointcloud plot: 3D view, internal nodes in black and boundary nodes in red
#(alternative: scatter3D of the internal nodes coloured by log10.(octreeSizeMap), cmap="viridis")
figure();
for (nodeSet, marker) in ((internalNodes, "k."), (boundaryNodes, "r."))
    plot3D(pointcloud[1,nodeSet], pointcloud[2,nodeSet], pointcloud[3,nodeSet], marker);
end
title("Pointcloud plot");
axis("equal");
display(gcf());

#cross-section plot: the same nodes projected onto the y-z plane
figure();
for (nodeSet, marker) in ((internalNodes, "k."), (boundaryNodes, "r."))
    plot(pointcloud[2,nodeSet], pointcloud[3,nodeSet], marker);
end
title("Cross-section plot");
axis("equal");
display(gcf());